[1]. Load Dependencies¶

In [2]:
# Load all necessary libraries and dependencies
import os
import cv2
import random
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

from tqdm import tqdm

# Load Tensorflow layer types
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, Cropping2D
from tensorflow.keras.models import Model

[2]. Create Datasets¶

In [3]:
# Create face detector object from the Open Computer Vision library
# (Haar cascade XML shipped inside the cv2 package's data directory)
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')
In [4]:
# Load in complete dataset of faces
# download = False: the TFDS-formatted CelebA files must already be staged
# at data_dir (here, an attached Kaggle input dataset)
train_celeb, test_celeb = tfds.load("celeb_a",
                                    split = ["train", "test"],
                                    shuffle_files = False,
                                    data_dir = '/kaggle/input/tfds-celeba-dataset',
                                    download = False)
In [5]:
# Function to apply Gaussian blur to faces
def blur_faces(image, face, strength):
    """Gaussian-blur every detected face region of `image`, in place.

    Parameters
    ----------
    image : ndarray (H, W, 3); modified in place and also returned.
    face : iterable of (x, y, w, h) face bounding boxes.
    strength : int, Gaussian kernel size. cv2.GaussianBlur requires an odd,
        positive kernel size, so even values are rounded up to the next odd
        number instead of raising an error (bug fix).
    """
    # cv2.GaussianBlur rejects even kernel sizes — make it odd
    kernel = strength if strength % 2 == 1 else strength + 1

    # Blur the detected face
    for (x, y, w, h) in face:

        # Blur only the face bounding box (sigma fixed at 30)
        face_roi = image[y:y+h, x:x+w]
        blurred_face = cv2.GaussianBlur(face_roi, (kernel, kernel), 30)

        # Write the blurred patch back into the image
        image[y:y+h, x:x+w] = blurred_face

    return image

# Function to pixelate faces
def pixelate_faces(image, face, strength):
    """Pixelate every detected face region of `image`, in place.

    Downscales each face box by a factor of `strength`, then upscales it back
    with nearest-neighbour interpolation to produce blocky pixels. The
    downscaled size is clamped to at least 1x1 so boxes smaller than
    `strength` no longer make cv2.resize raise on a zero-sized target
    (bug fix).

    Parameters
    ----------
    image : ndarray (H, W, 3); modified in place and also returned.
    face : iterable of (x, y, w, h) face bounding boxes.
    strength : int downscale factor (larger = blockier).
    """
    # Pixelate the detected face
    for (x, y, w, h) in face:

        # Clamp to >= 1 pixel so cv2.resize never receives a zero dimension
        face_roi = image[y:y+h, x:x+w]
        small = cv2.resize(face_roi, (max(1, w // strength), max(1, h // strength)))
        face_roi = cv2.resize(small, (w, h), interpolation = cv2.INTER_NEAREST)

        # Write the pixelated patch back into the image
        image[y:y+h, x:x+w] = face_roi

    return image

# Function to apply motion blur faces
def motion_blur_faces(image, face, strength):
    """Apply horizontal motion blur to each detected face region, in place.

    Builds a strength x strength kernel whose middle row is all ones
    (normalized), which smears pixels horizontally when convolved.
    Returns the same `image` array.
    """
    for (x, y, w, h) in face:
        # Extract the face patch
        patch = image[y:y+h, x:x+w]

        # Horizontal-line kernel: single center row of ones, normalized
        kernel = np.zeros((strength, strength))
        kernel[int((strength - 1)/2), :] = np.ones(strength)
        kernel = kernel / strength

        # Convolve the patch and write it back into the image
        image[y:y+h, x:x+w] = cv2.filter2D(patch, -1, kernel)

    return image
In [10]:
# Create a dataset portion
def create_dataset(dataset, image_type, quarter, test_fold):
    """Build one quarter of a face dataset (or all of it, for the test fold).

    Parameters
    ----------
    dataset : a tfds CelebA split, iterated via tfds.as_numpy.
    image_type : 'Gauss', 'Pixel' or 'Motion' to obfuscate the detected
        face; any other value (e.g. 'Original') keeps the clean image.
    quarter : 1-4, which quarter of `dataset` to process.
    test_fold : when True, the quarter bounds are ignored and the whole
        split is processed.

    Returns a list of float16 face-cropped images normalized to [0, 1].
    """
    img_lst = []
    total_images = len(dataset)
    
    # Select first & last indices
    frst_idx = (total_images // 4) * (quarter - 1)
    last_idx = frst_idx + (total_images // 4)
    
    if not test_fold:
        # Progress bar total: iteration stops at last_idx for a quarter fold
        total_images = last_idx
    
    for i, image_dict in tqdm(enumerate(tfds.as_numpy(dataset)), total = total_images, desc = "Building Dataset"):
        
        # Skip to appropriate indices
        if not test_fold:
            # Find appropriate fold of the data
            if i < frst_idx:
                continue
            if i == last_idx:
                break

        # Filter out lesser quality images
        if image_dict['attributes']['Blurry'] or image_dict['attributes']['Eyeglasses']:
            continue
        
        image = image_dict['image']
        
        # Calculate the center of the image
        center_y = image.shape[0] // 2
        center_x = image.shape[1] // 2

        # Calculate the top and bottom bounds for cropping
        # (produces a square crop of side 2*center_x around the vertical center)
        top = max(0, center_y - center_x)
        bottom = min(image.shape[0], center_y + center_x)

        # Crop the region around the center
        cropped_img = image[top:bottom, :, :]
        
        # Detect the face
        # NOTE(review): tfds images are RGB, but COLOR_BGR2GRAY assumes BGR,
        # so the channel weights are swapped; detection still works in
        # practice — confirm this is intended.
        gray = cv2.cvtColor(cropped_img, cv2.COLOR_BGR2GRAY)
        faces = face_cascade.detectMultiScale(gray, scaleFactor = 1.1, minNeighbors = 5, minSize = (30, 30))
        
        # Only select one face if more are found
        if len(faces) == 0:
            continue
        if len(faces) > 1:
            faces = [faces[0]]
        
        # Apply filter
        if image_type == 'Gauss':
            cropped_img = blur_faces(cropped_img, faces, 5)
        elif image_type == 'Pixel':
            cropped_img = pixelate_faces(cropped_img, faces, 4)
        elif image_type == 'Motion':
            cropped_img = motion_blur_faces(cropped_img, faces, 15)

        # float16 halves memory at some cost in precision
        cropped_img = cropped_img.astype(np.float16)
        
        # Normalize pixels
        cropped_img /= 255.0
        img_lst.append(cropped_img)
        
    return img_lst


# Convert image dataset to TensorFlow dataset
def to_tf_dataset(images):
    """Wrap a list of image arrays as a tf.data.Dataset whose elements are
    dicts with a single 'image' key (mirroring the tfds element layout)."""
    stacked = np.asarray(images)
    wrap = lambda image: {'image': image}
    return tf.data.Dataset.from_tensor_slices(stacked).map(wrap)

I ran this cell manually for each filter type and data quarter to generate and save every dataset variant in the required format.

In [ ]:
# Build the first quarter of the clean ("Original") training targets,
# convert it to a tf.data.Dataset, save it compressed, then free memory
target_images1 = create_dataset(train_celeb, "Original", 1, False)
target_images1 = to_tf_dataset(target_images1)
target_images1.save('/kaggle/working/target1', compression = 'GZIP')
del target_images1
Building Dataset: 100%|██████████| 40692/40692 [11:24<00:00, 59.43it/s]

[3]. Load in Train & Validation Data¶

In [3]:
# Load in target data: four saved shards, concatenated into one dataset
_target_shards = [
    tf.data.Dataset.load(f'/kaggle/input/celeba-faces-dataset/target{i}/target{i}', compression = "GZIP")
    for i in range(1, 5)
]
target1_df, target2_df, target3_df, target4_df = _target_shards

# Merge target datasets
target_df = target1_df
for _shard in (target2_df, target3_df, target4_df):
    target_df = target_df.concatenate(_shard)
In [4]:
# Load in Gaussian blurred data: four saved shards, concatenated
_blurry_shards = [
    tf.data.Dataset.load(f'/kaggle/input/celeba-faces-dataset/blurry{i}/blurry{i}', compression = "GZIP")
    for i in range(1, 5)
]
blurry1_df, blurry2_df, blurry3_df, blurry4_df = _blurry_shards

# Merge Gaussian blurred datasets
blurry_df = blurry1_df
for _shard in (blurry2_df, blurry3_df, blurry4_df):
    blurry_df = blurry_df.concatenate(_shard)
In [5]:
# Load in pixelated data: four saved shards, concatenated
_pixels_shards = [
    tf.data.Dataset.load(f'/kaggle/input/celeba-faces-dataset/pixels{i}/pixels{i}', compression = "GZIP")
    for i in range(1, 5)
]
pixels1_df, pixels2_df, pixels3_df, pixels4_df = _pixels_shards

# Merge pixelated datasets
pixels_df = pixels1_df
for _shard in (pixels2_df, pixels3_df, pixels4_df):
    pixels_df = pixels_df.concatenate(_shard)
In [6]:
# Load in motion blurred data: four saved shards, concatenated
_motion_shards = [
    tf.data.Dataset.load(f'/kaggle/input/celeba-faces-dataset/motion{i}/motion{i}', compression = "GZIP")
    for i in range(1, 5)
]
motion1_df, motion2_df, motion3_df, motion4_df = _motion_shards

# Merge motion blurred datasets
motion_df = motion1_df
for _shard in (motion2_df, motion3_df, motion4_df):
    motion_df = motion_df.concatenate(_shard)
In [7]:
# Load in validation data (one held-out test split per image type)
_val_root = '/kaggle/input/celeba-faces-dataset'
val_target = tf.data.Dataset.load(f'{_val_root}/target_test/target_test', compression = "GZIP")
val_blurry = tf.data.Dataset.load(f'{_val_root}/blurry_test/blurry_test', compression = "GZIP")
val_pixels = tf.data.Dataset.load(f'{_val_root}/pixels_test/pixels_test', compression = "GZIP")
val_motion = tf.data.Dataset.load(f'{_val_root}/motion_test/motion_test', compression = "GZIP")
In [8]:
# Pair each corrupted dataset with the clean targets for supervised training
_zip = tf.data.Dataset.zip
combined_blurry_df = _zip((blurry_df, target_df))
combined_pixels_df = _zip((pixels_df, target_df))
combined_motion_df = _zip((motion_df, target_df))

# Pair the validation splits the same way
val_blurry_df = _zip((val_blurry, val_target))
val_pixels_df = _zip((val_pixels, val_target))
val_motion_df = _zip((val_motion, val_target))
In [10]:
# Define pipelining function
def pair_images(element_a, element_b):
    """Unpack a (corrupted, clean) pair of dataset elements into an
    (input image, target image) tuple for supervised training."""
    return element_a['image'], element_b['image']

# Initialize image-pairing operation for training data
# (maps each pair of {'image': ...} dicts to an (input, target) tensor tuple)
blurry_pipe = combined_blurry_df.map(pair_images)
pixels_pipe = combined_pixels_df.map(pair_images)
motion_pipe = combined_motion_df.map(pair_images)

# Initialize image-pairing operation for validation data
blurry_validation = val_blurry_df.map(pair_images)
pixels_validation = val_pixels_df.map(pair_images)
motion_validation = val_motion_df.map(pair_images)

[4]. Plot Sample Images¶

In [18]:
# Define the titles for each row
titles = ['Original Image', 'Gaussian Blur', 'Pixelation', 'Motion Blur']

# Select 4 random images
# NOTE(review): no random seed is set, so each run shows different samples
selected_indices = random.sample(range(200), 4)

# Plot all images with titles
# Each column is one sampled index; each row is one corruption type.
# NOTE(review): skip(idx).take(1) re-iterates the pipeline from the start
# on every lookup, which is slow for large idx — confirm acceptable here.
plt.figure(figsize = (9, 9))
for i, idx in enumerate(selected_indices):
    # Plot original images (pair element [1] is the clean target)
    plt.subplot(4, 4, i + 1, frameon = True)
    image = blurry_pipe.skip(idx).take(1)
    image = np.expand_dims(next(iter(image))[1], axis = 0)
    plt.imshow((image[0] * 255).astype(int))
    plt.axis('off')
    if i == 0:
        # Add title
        plt.text(-0.1, 0.5, titles[0], fontsize = 10, ha = 'right', va = 'center',
                 rotation = 90, transform = plt.gca().transAxes)
        
    # Plot blurry images with titles (pair element [0] is the corrupted input)
    plt.subplot(4, 4, i + 5, frameon = True)
    image = blurry_pipe.skip(idx).take(1)
    image = np.expand_dims(next(iter(image))[0], axis = 0)
    plt.imshow((image[0] * 255).astype(int))
    plt.axis('off')
    if i == 0:
        # Add title
        plt.text(-0.1, 0.5, titles[1], fontsize = 10, ha = 'right', va = 'center',
                 rotation = 90, transform = plt.gca().transAxes)
        
    # Plot pixelated images with titles
    plt.subplot(4, 4, i + 9, frameon = True)
    image = pixels_pipe.skip(idx).take(1)
    image = np.expand_dims(next(iter(image))[0], axis = 0)
    plt.imshow((image[0] * 255).astype(int))
    plt.axis('off')
    if i == 0:
        # Add title
        plt.text(-0.1, 0.5, titles[2], fontsize = 10, ha = 'right', va = 'center',
                 rotation = 90, transform = plt.gca().transAxes)
        
    # Plot motion blurred images with titles
    plt.subplot(4, 4, i + 13, frameon = True)
    image = motion_pipe.skip(idx).take(1)
    image = np.expand_dims(next(iter(image))[0], axis = 0)
    plt.imshow((image[0] * 255).astype(int))
    plt.axis('off')
    if i == 0:
        # Add title
        plt.text(-0.1, 0.5, titles[3], fontsize = 10, ha = 'right', va = 'center',
                 rotation = 90, transform = plt.gca().transAxes)

plt.tight_layout()
plt.show()

[5]. Train Autoencoder on Gaussian Blur Images (Easier Task)¶

In [31]:
def simple_autoencoder(input_shape):
    """Build and compile a small convolutional autoencoder.

    Two 32-filter conv + pool stages encode the image; two conv + upsample
    stages decode it back to 3 channels (sigmoid, so outputs lie in [0, 1]).
    The final Cropping2D trims one pixel from every side so the output
    matches a 178x178 input despite the padded pooling/upsampling.
    """
    image_in = Input(shape = input_shape)

    # Encoder: two conv + downsample stages
    h = Conv2D(32, (3, 3), activation = 'relu', padding = 'same')(image_in)
    h = MaxPooling2D((2, 2), padding = 'same')(h)
    h = Conv2D(32, (3, 3), activation = 'relu', padding = 'same')(h)
    latent = MaxPooling2D((2, 2), padding = 'same')(h)

    # Decoder: two conv + upsample stages, then project to RGB
    h = Conv2D(32, (3, 3), activation = 'relu', padding = 'same')(latent)
    h = UpSampling2D((2, 2))(h)
    h = Conv2D(32, (3, 3), activation = 'relu', padding = 'same')(h)
    h = UpSampling2D((2, 2))(h)
    rgb_out = Conv2D(3, (3, 3), activation = 'sigmoid', padding = 'same')(h)

    # Trim the 1-pixel border introduced by the padded pooling
    trimmed = Cropping2D(cropping=((1, 1), (1, 1)))(rgb_out)

    model = Model(image_in, trimmed)
    model.compile(optimizer = 'adam', loss = 'mean_squared_error')
    return model
In [33]:
# Define input shape
input_shape = (178, 178, 3)

# Build the autoencoder
autoencoder = simple_autoencoder(input_shape)
print("Number of Parameters:", autoencoder.count_params())

# Train the autoencoder
# NOTE(review): Keras ignores `shuffle` when the input is a tf.data.Dataset;
# apply blurry_pipe.shuffle(...) before batching if shuffling is intended.
autoencoder.fit(blurry_pipe.batch(16).prefetch(3),
                epochs = 3,
                shuffle = True,
                verbose = True)
Number of Parameters: 29507
Epoch 1/3
8610/8610 ━━━━━━━━━━━━━━━━━━━━ 272s 31ms/step - loss: 0.0034
Epoch 2/3
8610/8610 ━━━━━━━━━━━━━━━━━━━━ 213s 25ms/step - loss: 0.0011
Epoch 3/3
8610/8610 ━━━━━━━━━━━━━━━━━━━━ 212s 25ms/step - loss: 9.9701e-04
Out[33]:
<keras.src.callbacks.history.History at 0x7ae2127a6860>
In [34]:
# Compare targets, blurred inputs, and reconstructions for four fixed indices
# (same four indices repeated once per row of the 3x4 grid)
indices = [0, 21, 24, 30,
           0, 21, 24, 30,
           0, 21, 24, 30]
fig, axes = plt.subplots(3, 4, figsize = (9,7), subplot_kw = {"xticks": [], "yticks": []})

for i, index in enumerate(indices):
    ax = axes.flat[i]
    
    if i == 0:
        ax.text(-0.1, 0.5, "Target Images", fontsize=10, ha='center', va='center', rotation=90, transform=ax.transAxes)
    elif i == 4:
        ax.text(-0.1, 0.5, "Gaussian Blur Images", fontsize=10, ha='center', va='center', rotation=90, transform=ax.transAxes)
    elif i == 8:
        ax.text(-0.1, 0.5, "Deblurred Images", fontsize=10, ha='center', va='center', rotation=90, transform=ax.transAxes)
    
    if i < 4:
        # Row 1: clean target image (pair element [1])
        image_with_batch = blurry_pipe.skip(index).take(1)
        image_with_batch = np.expand_dims(next(iter(image_with_batch))[1], axis = 0)
        ax.imshow((image_with_batch[0] * 255).astype(int))
    elif i < 8:
        # Row 2: Gaussian-blurred input (pair element [0])
        image_with_batch = blurry_pipe.skip(index).take(1)
        image_with_batch = np.expand_dims(next(iter(image_with_batch))[0], axis = 0)
        ax.imshow((image_with_batch[0] * 255).astype(int))
    else:
        # Row 3: autoencoder reconstruction of the blurred input
        image_with_batch = blurry_pipe.skip(index).take(1)
        image_with_batch = np.expand_dims(next(iter(image_with_batch))[0], axis = 0)
        reconstructed_image = autoencoder.predict(image_with_batch)
        ax.imshow((reconstructed_image[0] * 255).astype(int))
    
plt.tight_layout()
1/1 ━━━━━━━━━━━━━━━━━━━━ 1s 1s/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step

[6]. Train Autoencoder on Pixelated Images¶

In [1]:
# Records of each tested architecture (one dict per model), later converted
# to a DataFrame for plotting.
# Bug fix: this was initialized as dict(), but later cells call .append()
# on it and pass it to pd.DataFrame() as a list of records — use a list.
model_comparison = []
In [30]:
def build_autoencoder(input_shape, filters = 128):
    """Build and compile a convolutional autoencoder.

    Parameters
    ----------
    input_shape : tuple, e.g. (178, 178, 3).
    filters : int, width of every hidden Conv2D layer. Defaults to 128 (the
        original hard-coded value), so existing callers are unchanged; the
        parameter makes architecture sweeps trivial.

    Returns a compiled Model (Adam optimizer, MSE loss). The Cropping2D trims
    the 1-pixel border introduced by the padded pooling/upsampling so the
    output matches a 178x178 input.
    """
    inputs = Input(shape = input_shape)

    # Encoder: two conv + downsample stages
    x = Conv2D(filters, (3, 3), activation = 'relu', padding = 'same')(inputs)
    x = MaxPooling2D((2, 2), padding = 'same')(x)
    x = Conv2D(filters, (3, 3), activation = 'relu', padding = 'same')(x)
    encoded = MaxPooling2D((2, 2), padding = 'same')(x)

    # Decoder: two conv + upsample stages, then project to RGB in [0, 1]
    x = Conv2D(filters, (3, 3), activation = 'relu', padding = 'same')(encoded)
    x = UpSampling2D((2, 2))(x)
    x = Conv2D(filters, (3, 3), activation = 'relu', padding = 'same')(x)
    x = UpSampling2D((2, 2))(x)
    decoded = Conv2D(3, (3, 3), activation = 'sigmoid', padding = 'same')(x)

    cropped_decoded = Cropping2D(cropping=((1, 1), (1, 1)))(decoded)

    autoencoder = Model(inputs, cropped_decoded)
    autoencoder.compile(optimizer = 'adam', loss = 'mean_squared_error')
    return autoencoder
In [78]:
# Build the autoencoder
# (reuses the (178, 178, 3) input_shape defined for the Gaussian-blur model)
autoencoder = build_autoencoder(input_shape)

# Train the autoencoder
# NOTE(review): Keras ignores `shuffle` when the input is a tf.data.Dataset.
autoencoder.fit(pixels_pipe.batch(16).prefetch(3),
                epochs = 3,
                shuffle = True,
                verbose = True,
                validation_data = pixels_validation.batch(16).prefetch(3))
Epoch 1/3
8610/8610 ━━━━━━━━━━━━━━━━━━━━ 806s 91ms/step - loss: 0.0029 - val_loss: 0.0012
Epoch 2/3
8610/8610 ━━━━━━━━━━━━━━━━━━━━ 764s 89ms/step - loss: 0.0012 - val_loss: 0.0011
Epoch 3/3
8610/8610 ━━━━━━━━━━━━━━━━━━━━ 762s 89ms/step - loss: 0.0011 - val_loss: 0.0010
Out[78]:
<keras.src.callbacks.history.History at 0x78595c67bdf0>

I iterated through different model architectures, recording how each performed; six candidate architectures were tested in total.

In [79]:
# Evaluate MSE on the full training and validation pipelines
train_loss = autoencoder.evaluate(pixels_pipe.batch(64).prefetch(3))
valid_loss = autoencoder.evaluate(pixels_validation.batch(64).prefetch(3))
# Record this architecture's results for later comparison.
# NOTE(review): .append requires model_comparison to be a list, but it is
# initialized as dict() earlier in the notebook — confirm/fix the initializer.
model_comparison.append({"Model Name": "128, 128",
                         "Parameters": autoencoder.count_params(),
                         "Train MSE": train_loss,
                         "Validation MSE": valid_loss})
2153/2153 ━━━━━━━━━━━━━━━━━━━━ 246s 104ms/step - loss: 0.0010
266/266 ━━━━━━━━━━━━━━━━━━━━ 38s 143ms/step - loss: 0.0010
In [46]:
# Convert to DataFrame
# (one row per tested architecture, accumulated in model_comparison)
df = pd.DataFrame(model_comparison)

# Create a figure with two subplots
fig, (ax1, ax2) = plt.subplots(1, 2, figsize = (16, 6))

# Bar plot for Number of Parameters
ax1.bar(df['Model Name'], df['Parameters'], color = 'dodgerblue', edgecolor = 'black')
ax1.set_xlabel('Model')
ax1.set_ylabel('Number of Parameters')
ax1.set_title('Number of Parameters in Each Model')
ax1.grid(alpha = 0.5)

# Bar plot for Validation MSE
ax2.bar(df['Model Name'], df['Validation MSE'], color = 'darkviolet', edgecolor = 'black')
ax2.set_xlabel('Model')
ax2.set_ylabel('Validation MSE')
ax2.set_title('Validation MSE for Each Model')
ax2.grid(alpha = 0.5)

plt.tight_layout()
plt.show()
In [160]:
# Compare targets, pixelated inputs, and reconstructions on validation data
# (same four indices repeated once per row of the 3x4 grid)
indices = [0, 21, 24, 30,
           0, 21, 24, 30,
           0, 21, 24, 30]
fig, axes = plt.subplots(3, 4, figsize = (9,7), subplot_kw = {"xticks": [], "yticks": []})

for i, index in enumerate(indices):
    ax = axes.flat[i]
    
    if i == 0:
        ax.text(-0.1, 0.5, "Target Images", fontsize=10, ha='center', va='center', rotation=90, transform=ax.transAxes)
    elif i == 4:
        ax.text(-0.1, 0.5, "Pixelated Images", fontsize=10, ha='center', va='center', rotation=90, transform=ax.transAxes)
    elif i == 8:
        ax.text(-0.1, 0.5, "Depixelated Images", fontsize=10, ha='center', va='center', rotation=90, transform=ax.transAxes)
    
    if i < 4:
        # Row 1: clean target image (pair element [1])
        image_with_batch = pixels_validation.skip(index).take(1)
        image_with_batch = np.expand_dims(next(iter(image_with_batch))[1], axis = 0)
        ax.imshow((image_with_batch[0] * 255).astype(int))
    elif i < 8:
        # Row 2: pixelated input (pair element [0])
        image_with_batch = pixels_validation.skip(index).take(1)
        image_with_batch = np.expand_dims(next(iter(image_with_batch))[0], axis = 0)
        ax.imshow((image_with_batch[0] * 255).astype(int))
    else:
        # Row 3: autoencoder reconstruction of the pixelated input
        image_with_batch = pixels_validation.skip(index).take(1)
        image_with_batch = np.expand_dims(next(iter(image_with_batch))[0], axis = 0)
        reconstructed_image = autoencoder.predict(image_with_batch)
        ax.imshow((reconstructed_image[0] * 255).astype(int))
    
plt.tight_layout()
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 24ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 23ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 23ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step

[7]. Train GAN on Motion Blurred Data¶

In [90]:
import tensorflow as tf
from tensorflow.keras import layers, models, losses, optimizers

# Define the generator network
def build_generator(input_shape):
    """Encoder-decoder generator: two strided-conv downsamples followed by
    two transposed-conv upsamples back to 3 channels, with a Cropping2D to
    trim the border so the output matches the 178x178 input.

    NOTE(review): the final activation is 'tanh' (range [-1, 1]) while the
    training targets appear to be normalized to [0, 1] — confirm intended.
    """
    stages = [
        layers.Input(shape = input_shape),
        # Downsampling path
        layers.Conv2D(64, (4, 4), strides = (2, 2), padding = 'same', use_bias = False),
        layers.BatchNormalization(),
        layers.LeakyReLU(negative_slope = 0.2),
        layers.Conv2D(128, (4, 4), strides = (2, 2), padding = 'same', use_bias = False),
        layers.BatchNormalization(),
        layers.LeakyReLU(negative_slope = 0.2),
        # Upsampling path
        layers.Conv2DTranspose(64, (4, 4), strides = (2, 2), padding = 'same', use_bias = False),
        layers.BatchNormalization(),
        layers.ReLU(),
        layers.Conv2DTranspose(3, (4, 4), strides = (2, 2), padding = 'same', activation = 'tanh'),
        # Trim the border to restore the input spatial size
        layers.Cropping2D(cropping = ((1, 1), (1, 1))),
    ]
    return models.Sequential(stages)

# Define the discriminator network
def build_discriminator(input_shape):
    """Convolutional discriminator: three strided conv blocks followed by a
    single unactivated score (raw output) per image."""
    blocks = [
        layers.Input(shape=input_shape),
        # Block 1 (no batch norm on the first layer)
        layers.Conv2D(32, (4, 4), strides = (2, 2), padding='same'),
        layers.LeakyReLU(negative_slope = 0.2),
        # Block 2
        layers.Conv2D(64, (4, 4), strides = (2, 2), padding='same'),
        layers.BatchNormalization(),
        layers.LeakyReLU(negative_slope = 0.2),
        # Block 3
        layers.Conv2D(128, (4, 4), strides = (2, 2), padding='same'),
        layers.BatchNormalization(),
        layers.LeakyReLU(negative_slope = 0.2),
        # Flatten to a single scalar score
        layers.Flatten(),
        layers.Dense(1),
    ]
    return models.Sequential(blocks)

# Define the generator loss function
def generator_loss(fake_output):
    """Least-squares GAN generator loss: pushes the discriminator's scores
    on generated images toward 1 (the 'real' label)."""
    real_labels = tf.ones_like(fake_output)
    return losses.mean_squared_error(real_labels, fake_output)

# Define the discriminator loss function
def discriminator_loss(real_output, fake_output):
    """Least-squares GAN discriminator loss: real scores are pushed toward 1
    and fake scores toward 0; the two terms are summed."""
    loss_on_real = losses.mean_squared_error(tf.ones_like(real_output), real_output)
    loss_on_fake = losses.mean_squared_error(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake

# Define the generator and discriminator
input_shape = (178, 178, 3)
generator = build_generator(input_shape)
discriminator = build_discriminator(input_shape)

# Define optimizers
# Bug fix: the original assigned optimizers.Adam() to discriminator_optimizer
# twice, so generator_optimizer was never defined even though the training
# step and the checkpoint both reference it.
generator_optimizer = optimizers.Adam()
discriminator_optimizer = optimizers.Adam()
In [91]:
# Bundle the models and their optimizers into one checkpoint object so the
# full training state can be saved and restored together.
# NOTE(review): generator_optimizer is referenced here, but the setup cell
# above assigns optimizers.Adam() to discriminator_optimizer twice and never
# creates generator_optimizer — confirm/fix before running.
checkpoint = tf.train.Checkpoint(generator_optimizer = generator_optimizer,
                                 discriminator_optimizer = discriminator_optimizer,
                                 generator = generator,
                                 discriminator = discriminator)

# Define the training loop
@tf.function
def train_step(images_blurred, images_clear):
    """Run one adversarial training step on a single batch.

    The generator maps the blurred images to reconstructions; the
    discriminator scores both the clear images and the reconstructions.
    One gradient update is applied to each network. Returns
    (gen_loss, disc_loss) for logging.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(images_blurred, training = True)
        
        real_output = discriminator(images_clear, training = True)
        fake_output = discriminator(generated_images, training = True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    # Separate tapes so each network is updated only w.r.t. its own loss
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    return gen_loss, disc_loss

epochs = 6

def train(dataset, epochs):
    """Run `epochs` passes of adversarial training over `dataset`, showing
    per-batch generator/discriminator losses on a tqdm progress bar."""
    for epoch_num in range(1, epochs + 1):
        batches = tqdm(dataset, desc=f'Epoch {epoch_num}/{epochs}', unit='batch')
        for blurred_batch, clear_batch in batches:
            gen_loss, disc_loss = train_step(blurred_batch, clear_batch)
            batches.set_postfix({'Generator Loss ': np.mean(gen_loss.numpy()), 'Discriminator Loss ': np.mean(disc_loss.numpy())})

# Train the model
train(motion_pipe.batch(256).prefetch(3), epochs)
Epoch 1/6: 100%|██████████| 539/539 [09:11<00:00,  1.02s/batch, Generator Loss =0.884, Discriminator Loss =0.27] 
Epoch 2/6: 100%|██████████| 539/539 [09:03<00:00,  1.01s/batch, Generator Loss =0.631, Discriminator Loss =0.219]
Epoch 3/6: 100%|██████████| 539/539 [09:03<00:00,  1.01s/batch, Generator Loss =0.832, Discriminator Loss =0.262]
Epoch 4/6: 100%|██████████| 539/539 [09:03<00:00,  1.01s/batch, Generator Loss =0.827, Discriminator Loss =0.189]
Epoch 5/6: 100%|██████████| 539/539 [09:02<00:00,  1.01s/batch, Generator Loss =0.832, Discriminator Loss =0.184]
Epoch 6/6: 100%|██████████| 539/539 [09:03<00:00,  1.01s/batch, Generator Loss =0.962, Discriminator Loss =0.168]
In [99]:
# Compare targets, motion-blurred inputs, and GAN outputs on validation data
# (same four indices repeated once per row of the 3x4 grid)
indices = [17, 19, 27, 39,
           17, 19, 27, 39,
           17, 19, 27, 39]
fig, axes = plt.subplots(3, 4, figsize = (9,7), subplot_kw = {"xticks": [], "yticks": []})

for i, index in enumerate(indices):
    ax = axes.flat[i]
    
    if i == 0:
        ax.text(-0.1, 0.5, "Target Images", fontsize=10, ha='center', va='center', rotation=90, transform=ax.transAxes)
    elif i == 4:
        ax.text(-0.1, 0.5, "Motion Blur Images", fontsize=10, ha='center', va='center', rotation=90, transform=ax.transAxes)
    elif i == 8:
        ax.text(-0.1, 0.5, "Deblurred, Artifacts", fontsize=10, ha='center', va='center', rotation=90, transform=ax.transAxes)
    
    if i < 4:
        # Row 1: clean target image (pair element [1])
        image_with_batch = motion_validation.skip(index).take(1)
        image_with_batch = np.expand_dims(next(iter(image_with_batch))[1], axis = 0)
        ax.imshow((image_with_batch[0] * 255).astype(int))
    elif i < 8:
        # Row 2: motion-blurred input (pair element [0])
        image_with_batch = motion_validation.skip(index).take(1)
        image_with_batch = np.expand_dims(next(iter(image_with_batch))[0], axis = 0)
        ax.imshow((image_with_batch[0] * 255).astype(int))
    else:
        # Row 3: generator output for the motion-blurred input
        image_with_batch = motion_validation.skip(index).take(1)
        image_with_batch = np.expand_dims(next(iter(image_with_batch))[0], axis = 0)
        reconstructed_image = generator.predict(image_with_batch)
        ax.imshow((reconstructed_image[0] * 255).astype(int))
    
plt.tight_layout()
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 22ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 22ms/step